[executorch][cuda] fuse gate/up MLP projections #20482
Conversation
Three CUDA-export memory optimizations: - tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner auto-prunes configs that exceed a GPU's shared memory (OutOfResources -> inf), so the same config list also works on the 5090 (Blackwell, ~101 KB SMEM) where the previous smallest config did not fit. - int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized weights (N>65536, i.e. only the lm_head). Avoids transiently materializing the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a shim and the M>4 prefill inline path is below the threshold, so this never enters the runtime graph -> zero runtime / accuracy impact. Applied unconditionally (no flag). - cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache buffers during AOTI compile (gated behind low_memory_mode). A new move_program_to_device hook places KV constants on the target device but immediately frees their storage (resize_(0)), so the fake-tensor device check passes while no real KV bytes sit on the GPU during autotune. The emptied buffers are re-synthesized as zeros at the _unlift_graph clone and at serialization, and excluded from constant dedup (resize_(0) gives every KV data_ptr 0, which would otherwise collapse same-shape caches across layers). Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the exported model runs correctly (output "...Paris.").
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20482
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 1 Pending, 4 Unrelated Failures, 1 Unclassified FailureAs of commit 4025660 with merge base 1b726b2 ( NEW FAILURES - The following jobs have failed:
UNCLASSIFIED FAILURE - DrCI could not classify the following job because the workflow did not run on the merge base. The failure may be pre-existing on trunk or introduced by this PR:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
This PR needs a
|
8b145b5 to
1c371e2
Compare
…nly frees genuinely all-zero kv_cache.* buffers (count_nonzero==0); preserves TQ4 centroids/boundaries/rotation/rotation_T
Summary: Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct). Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to ~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at every measured context (cuda_graph ON): ctx ET llama 512 44.80 42.77 2K 43.20 41.97 8K 42.23 41.23 32K 41.64 40.27 127K 38.41 35.97 TurboQuant KV compression kept; prefill restored (6-8x) with no regression; output quality preserved. Test Plan: - Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and prefill (T=4). - Export + run: fused module exported via CudaPartitioner and executed through executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs "Fused gate+up on 60 MLP layers". - Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode.
1c371e2 to
4025660
Compare
Summary:
Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct).
Next Step: we will upsteam this kind of operator fusion into gemma4-31b model level when loading gguf. #20481 is the draft PR